import numpy as np
import random
from random import choice
from scipy.stats import bernoulli
from scipy.stats import norm
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

# Numerical tolerance used by the sampling decomposition and the bias search.
# NOTE: the literal 10E-12 evaluates to 1e-11, not 1e-12.
EPS = 10E-12

# Number of rounds the main loop simulates.
T = 50000
# Number of arms.
K = 16
# Number of agents instantiated per algorithm (the main loop runs `for v in range(V)`).
V = 1
# Countdown budgets, one slot per algorithm. While C[j] > 0 the main loop feeds
# algorithm j fixed 0/1 feedback determined by `target` instead of the sampled
# rewards. Only slots 1, 4, 5, 6 and 7 are read in this file.
C = [5000] * 8
# Passed as `d` to the bandit and to every agent; the bandit uses it as the
# number of top arms whose means form the regret benchmark.
d = 4
# During the fixed-feedback phase, arms with index < target are reported as
# loss 1 (reward 0) and the remaining arms as loss 0 (reward 1).
target = K - 2

# Multiplier used when computing DS_lambda_m.
num = 16

class Bandit:
    """Environment with K arms whose mean rewards are evenly spaced in [0.02, 0.96].

    The means are stored in decreasing order, so arm 0 has the largest mean.
    ``best_U`` is the sum of the first ``d`` means and serves as the benchmark
    against which ``generate_regret`` measures a chosen set of arms.
    """

    def __init__(self, K, d):
        self.K = K
        # Evenly spaced means, reversed so that they decrease with the arm index.
        self.U = np.linspace(0.02, 0.96, K)[::-1]
        # Accumulate left to right starting from 0, one mean per index in range(d).
        self.best_U = sum(self.U[arm] for arm in range(d))

    def generate_award(self):
        """Draw one reward per arm: mean plus N(0, 0.1) noise, clipped to [0, 1].

        The draw is stored on ``self.generate_U`` and also returned.
        """
        noisy = self.U + np.random.normal(0, 0.1, size=self.K)
        self.generate_U = np.clip(noisy, 0, 1)
        return self.generate_U

    def generate_regret(self, action_set):
        """Return ``best_U`` minus the means of the arms listed in ``action_set``."""
        gap = self.best_U
        for arm in action_set:
            gap -= self.U[arm]
        return gap

class Meta_Algorithm:
    """Common base for the learners: stores K and d and samples an action from x.

    Subclasses implement ``next``, ``reset`` and ``update``. ``sample_action``
    turns a vector ``x`` of per-arm values into a list of arm indices using the
    weighted pieces produced by ``split_sample``.
    """

    def __init__(self, K, d):
        self.K = K
        self.d = d
        self.unconstrained = False

    def next(self):
        raise NotImplementedError

    def reset(self) -> object:
        raise NotImplementedError

    def update(self, action, feedback):
        raise NotImplementedError

    def sample_action(self, x):
        """Sample a list of arm indices from the per-arm values ``x``.

        Arms are ranked by decreasing ``x``. One (weight, left, right) piece
        from ``split_sample`` is picked with probability equal to its weight.
        The action then consists of the ``left`` top-ranked arms plus
        ``d - left`` arms drawn without replacement from ranks
        ``[left, right)``. A piece with ``left == right - 1`` selects the
        ``d`` top-ranked arms.
        """
        ranking = np.argsort(-x)
        included = np.copy(x[ranking])
        remaining = 1.0 - included
        pieces = list(self.split_sample(included, remaining))
        weights = [piece[0] for piece in pieces]
        picked = np.random.choice(len(pieces), p=weights)
        _, left, right = pieces[picked]
        if left == right - 1:
            positions = range(self.d)
        else:
            pool = list(range(left, right))
            random.shuffle(pool)
            positions = list(range(left)) + pool[:self.d - left]
        return [ranking[pos] for pos in positions]

    def split_sample(self, included, remaining):
        """Yield (weight, left, right) pieces consumed by ``sample_action``.

        ``included`` and ``remaining`` are modified in place as weight is
        peeled off. ``included`` is expected in decreasing order with
        ``remaining = 1 - included``, as ``sample_action`` supplies them.
        """
        mass_left = 1.0
        lo, hi = 0, self.K
        width_bound = self.K
        while lo < hi:
            width_bound -= 1
            frac_in = (self.d - lo) / (hi - lo)
            frac_out = 1.0 - frac_in
            if frac_in == 0 or frac_out == 0:
                # Nothing left to split: the rest of the mass goes to this window.
                yield (mass_left, lo, hi)
                return
            # Largest weight that keeps both vectors non-negative after the
            # subtraction below.
            weight = min(included[hi - 1] / frac_in, remaining[lo] / frac_out)
            yield weight, lo, hi
            mass_left -= weight
            assert mass_left >= -EPS
            included -= weight * frac_in
            remaining -= weight * frac_out
            # Shrink the window past entries that have been used up.
            while hi > 0 and included[hi - 1] <= EPS:
                hi -= 1
            while lo < self.K and remaining[lo] <= EPS:
                lo += 1
            assert hi - lo <= width_bound
        if mass_left > 0.0:
            # Leftover mass: a piece with left == right - 1, which
            # sample_action maps to the d top-ranked arms.
            yield (mass_left, self.d, self.d + 1)

class OSMD(Meta_Algorithm):
    """Base class for learners that recompute a per-arm vector ``x`` each round.

    ``next`` refreshes the learning rate, recomputes ``self.x`` from the
    cumulative loss estimates ``self.L`` via ``solve_optimization`` and then
    samples an action from ``self.x``. Subclasses provide
    ``solve_unconstrained`` (per-arm solver), ``hessian_inverse`` (slope used
    by the bias search) and ``update``. ``reset`` must be called before the
    first ``next`` because it creates ``L`` and ``bias``.
    """

    L = None  # cumulative loss estimates, one per arm (created by reset)
    x = None  # current per-arm values
    time_step = 0.0
    gamma = 1.0
    learning_rate = None
    bias = None  # scalar added to every loss while solving the constrained problem

    def __init__(self, K, d):
        super().__init__(K, d)
        self.x = np.zeros(K)

    def next(self, ):
        """Advance one round and return the sampled action (list of arm indices)."""
        self.time_step += 1.0
        self.learning_rate = self.get_learning_rate(self.time_step)
        self.solve_optimization()
        return self.sample_action(self.x)

    def reset(self):
        """Zero the loss estimates and restart from the uniform point."""
        self.L = np.array([0.0] * self.K)
        self.x = [self.d / self.K if not self.unconstrained else 0.5 for _ in range(self.K)]
        self.time_step = 0.0
        self.bias = 0

    def solve_optimization(self):
        """Recompute ``self.x`` from ``self.L``.

        In unconstrained mode each arm is solved independently. Otherwise a
        scalar ``bias`` is searched such that ``sum(x) == d``, mixing Newton
        steps with bracketing/bisection safeguards.
        """
        if self.unconstrained:
            # Pass the arm index as well: every subclass solver takes it as a
            # required third argument.
            self.x = np.array([
                self.solve_unconstrained(l * self.learning_rate, x, i)
                for i, (l, x) in enumerate(zip(self.L, self.x))
            ])
            return

        max_iter = 1000
        iteration = 0
        upper = None
        lower = None
        step_size = 1
        self.index = np.arange(0, self.K, 1)
        while True:
            iteration += 1
            self.x = np.array(
                [self.solve_unconstrained((l + self.bias) * self.learning_rate, x, i)
                 for l, x, i in zip(self.L, self.x, self.index)])
            f = self.x.sum() - self.d
            df = self.hessian_inverse()
            # Newton proposal for the bias at which f would vanish.
            next_bias = self.bias + f / df
            if f > 0:
                # The sum is too large: the current bias becomes the lower end
                # of the bracket.
                lower = self.bias
                self.bias = next_bias
                if upper is None:
                    # No upper end yet: cap the move with a doubling step.
                    step_size *= 2
                    if next_bias > lower + step_size:
                        self.bias = lower + step_size
                else:
                    # Newton left the bracket: fall back to bisection.
                    if next_bias > upper:
                        self.bias = (lower + upper) / 2
            else:
                upper = self.bias
                self.bias = next_bias
                if lower is None:
                    step_size *= 2
                    if next_bias < upper - step_size:
                        self.bias = upper - step_size
                else:
                    if next_bias < lower:
                        self.bias = (lower + upper) / 2

            if iteration > max_iter or abs(f) < 100 * EPS:
                break

    def get_learning_rate(self, time):
        """Return the learning rate for round ``time`` (1 / sqrt(time) by default)."""
        return 1.0 / np.sqrt(time)

    def solve_unconstrained(self, loss, warmstart, i=None):
        """Solve the one-dimensional problem for one arm.

        Args:
            loss: scaled (and possibly biased) cumulative loss of the arm.
            warmstart: previous value of x for the arm.
            i: index of the arm; optional here for backward compatibility.
        """
        raise NotImplementedError

    def hessian_inverse(self):
        """Return the slope used as the divisor in the Newton step on the bias."""
        raise NotImplementedError


class HYBRID(OSMD):
    """OSMD learner whose per-arm equation combines an inverse-square-root term
    with a ``gamma``-weighted logarithmic term in ``1 - x``.
    """

    def __init__(self, K, d):
        super().__init__(K, d)
        if d is None or d < K / 2:
            self.gamma = 1.0
        elif d >= K:
            # log(K / (K - d)) diverges as d approaches K, so the weight
            # tends to zero; this also avoids a division by zero.
            self.gamma = 0.0
        else:
            # The ratio K / (K - d) is intended here. The previous expression
            # K - (K - d) is simply d and yields an infinite gamma when d == 1.
            self.gamma = np.sqrt(1.0 / np.log(K / (K - d)))

    def solve_unconstrained(self, loss, warmstart, i):
        """Return the root in x of
        ``loss - 0.5 / sqrt(x) + gamma * (1 - log(1 - x))``.

        Uses Newton's method started at ``warmstart``. A step that would move
        x to or below 0, or to or above 1, is replaced by half the distance to
        that bound. ``i`` (the arm index) is accepted for interface
        compatibility and is not used.
        """
        x_val = warmstart
        while True:
            func_val = loss - 0.5 / np.sqrt(x_val) + self.gamma * (1.0 - np.log(1.0 - x_val))
            # Derivative of func_val with respect to x_val.
            dif_func_val = 0.25 / (np.sqrt(x_val) ** 3) + self.gamma / (1.0 - x_val)
            dif_x = func_val / dif_func_val
            if dif_x > x_val:
                dif_x = x_val / 2
            elif dif_x < x_val - 1.0:
                dif_x = (x_val - 1.0) / 2
            if abs(dif_x) < EPS:
                break
            x_val -= dif_x
        return x_val

    def hessian_inverse(self):
        """Return ``learning_rate * sum_i 1 / g'(x_i)``, where g' is the
        derivative used in ``solve_unconstrained``."""
        return (1.0 / (0.25 / np.power(self.x, 1.5) + self.gamma / (1.0 - self.x))).sum() * self.learning_rate

    def update(self, action, feedback):
        """Add ``(feedback + 1) / x`` to the loss estimates of the played arms,
        then fold the current bias into ``L`` and reset it (or subtract 1 from
        every entry in unconstrained mode)."""
        if len(action):
            self.L[action] += np.divide((np.array(feedback) + 1.0), self.x[action])
        if self.unconstrained:
            self.L -= 1
        else:
            self.L += self.bias
            self.bias = 0

class LBINF(OSMD):
    """OSMD learner with per-arm weights ``beta`` and loss predictions ``m``.

    ``gamma`` is ``log(T)`` for the horizon passed to the constructor. The
    learning rate is constant (1.0); adaptivity comes from ``beta``, which
    grows with the squared prediction errors observed on each arm.
    """

    def __init__(self, K, d, T):
        super().__init__(K, d)
        self.gamma = np.log(T)
        self.beta = np.array([np.sqrt(2)] * K)
        self.m = np.array([1. / 4.] * K)
        self.alpha = np.zeros(K)

    def solve_unconstrained(self, loss, warmstart, i):
        """Return the root in x of
        ``loss + m[i] - beta[i] * (1 / x + gamma * (log(1 - x) + 1))``.

        Uses Newton's method started at ``warmstart``. A step that would move
        x to or below 0, or to or above 1, is replaced by half the distance to
        that bound.
        """
        x_val = warmstart
        while True:
            func_val = (loss + self.m[i] - self.beta[i] * (1.0 / x_val + self.gamma * (np.log(1.0 - x_val) + 1.0)))
            # Derivative of func_val with respect to x_val.
            dif_func_val = (self.beta[i] / (x_val ** 2) + self.beta[i] * self.gamma / (1.0 - x_val))
            dif_x = func_val / dif_func_val

            if dif_x > x_val:
                dif_x = x_val / 2
            elif dif_x < x_val - 1.0:
                dif_x = (x_val - 1.0) / 2
            if abs(dif_x) < EPS:
                break
            x_val -= dif_x
        return x_val

    def hessian_inverse(self):
        """Return ``sum_i 1 / g'(x_i)``, where g' is the derivative used in
        ``solve_unconstrained``."""
        diag_hess = (self.beta / (self.x ** 2) + self.beta * self.gamma / (1.0 - self.x))
        inv_diag = 1.0 / (diag_hess)
        return inv_diag.sum()

    def update(self, action, feedback):
        """Update loss estimates, predictions and weights from one round.

        Args:
            action: indices of the arms that were played.
            feedback: observed losses, aligned with ``action``.
        """
        if len(action):
            self.L[action] += np.divide((np.array(feedback) + self.m[action]), self.x[action])
        if self.unconstrained:
            self.L -= 1
        else:
            self.L += self.bias
            self.bias = 0
        for arm, loss in zip(action, feedback):
            # Use the played arm's own x in both numerator and denominator;
            # indexing x by the position inside `action` picks an unrelated arm.
            x_arm = self.x[arm]
            # NOTE(review): LBINFLS and LBINFGD use 2 * (1 - x) in this cap;
            # confirm which variant is intended here.
            alpha = ((loss - self.m[arm]) ** 2) * min(1, (1 - x_arm) / (self.gamma * x_arm * x_arm))
            self.m[arm] += (loss - self.m[arm]) / 4
            # self.gamma is log(T) for the constructor's T, so the instance
            # does not depend on a module-level T.
            self.beta[arm] = np.sqrt(self.beta[arm] ** 2 + alpha / self.gamma)

    def get_learning_rate(self, time):
        """Return the constant learning rate 1.0."""
        return 1.0

class LBINFLS(OSMD):
    """Variant of the LBINF learner whose loss prediction for an arm is
    ``(0.5 + sum of observed losses) / (1 + number of plays)``.

    ``l`` and ``n`` hold the per-arm loss sums and play counts. ``gamma`` is
    ``log(T)`` for the horizon passed to the constructor.
    """

    def __init__(self, K, d, T):
        super().__init__(K, d)
        self.gamma = np.log(T)
        self.beta = np.array([np.sqrt(2)] * K)
        self.m = np.array([1. / 4.] * K)
        self.alpha = np.zeros(K)
        self.n = np.zeros(K)
        self.l = np.zeros(K)

    def solve_unconstrained(self, loss, warmstart, i):
        """Return the root in x of
        ``loss + m[i] + beta[i] * (1 - 1 / x - gamma * log(1 - x))``.

        Uses Newton's method started at ``warmstart``. A step that would move
        x to or below 0, or to or above 1, is replaced by half the distance to
        that bound.
        """
        x_val = warmstart
        while True:
            func_val = (loss + self.m[i] + self.beta[i] * (1 - 1.0 / x_val - self.gamma * (np.log(1.0 - x_val))))
            # Derivative of func_val with respect to x_val.
            dif_func_val = (self.beta[i] / (x_val ** 2) + self.beta[i] * self.gamma / (1.0 - x_val))
            dif_x = func_val / (dif_func_val)

            if dif_x > x_val:
                dif_x = x_val / 2
            elif dif_x < x_val - 1.0:
                dif_x = (x_val - 1.0) / 2
            if abs(dif_x) < EPS:
                break
            x_val -= dif_x
        return x_val

    def hessian_inverse(self):
        """Return ``sum_i 1 / g'(x_i)``, where g' is the derivative used in
        ``solve_unconstrained``."""
        diag_hess = (self.beta / (self.x ** 2) + self.beta * self.gamma / (1.0 - self.x))
        inv_diag = 1.0 / (diag_hess)
        return inv_diag.sum()

    def update(self, action, feedback):
        """Update loss estimates, predictions and weights from one round.

        Args:
            action: indices of the arms that were played.
            feedback: observed losses, aligned with ``action``.
        """
        if len(action):
            self.L[action] += np.divide((np.array(feedback) + self.m[action]), self.x[action])
        if self.unconstrained:
            self.L -= 1
        else:
            self.L += self.bias
            self.bias = 0
        for arm, loss in zip(action, feedback):
            self.l[arm] += loss
            self.n[arm] += 1
            # Use the played arm's own x in both numerator and denominator;
            # indexing x by the position inside `action` picks an unrelated arm.
            x_arm = self.x[arm]
            # The squared error is measured against the prediction made
            # before this round's observation is folded in.
            alpha = ((loss - self.m[arm]) ** 2) * min(1, 2 * (1 - x_arm) / (self.gamma * x_arm * x_arm))
            self.m[arm] = (0.5 + self.l[arm]) / (1 + self.n[arm])
            # self.gamma is log(T) for the constructor's T, so the instance
            # does not depend on a module-level T.
            self.beta[arm] = np.sqrt(self.beta[arm] ** 2 + alpha / self.gamma)

    def get_learning_rate(self, time):
        """Return the constant learning rate 1.0."""
        return 1.0

class LBINFGD(OSMD):
    """Variant of the LBINF learner whose loss prediction for an arm is moved
    a quarter of the way towards each newly observed loss.

    ``n`` counts plays per arm. ``gamma`` is ``log(T)`` for the horizon passed
    to the constructor.
    """

    def __init__(self, K, d, T):
        super().__init__(K, d)
        self.gamma = np.log(T)
        self.beta = np.array([np.sqrt(2)] * K)
        self.m = np.array([1. / 4.] * K)
        self.alpha = np.zeros(K)
        self.n = np.zeros(K)

    def solve_unconstrained(self, loss, warmstart, i):
        """Return the root in x of
        ``loss + m[i] + beta[i] * (1 - 1 / x - gamma * log(1 - x))``.

        Uses Newton's method started at ``warmstart``. A step that would move
        x to or below 0, or to or above 1, is replaced by half the distance to
        that bound.
        """
        x_val = warmstart
        while True:
            func_val = (loss + self.m[i] + self.beta[i] * (1 - 1.0 / x_val - self.gamma * (np.log(1.0 - x_val))))
            # Derivative of func_val with respect to x_val.
            dif_func_val = (self.beta[i] / (x_val ** 2) + self.beta[i] * self.gamma / (1.0 - x_val))
            dif_x = func_val / (dif_func_val)

            if dif_x > x_val:
                dif_x = x_val / 2
            elif dif_x < x_val - 1.0:
                dif_x = (x_val - 1.0) / 2
            if abs(dif_x) < EPS:
                break
            x_val -= dif_x
        return x_val

    def hessian_inverse(self):
        """Return ``sum_i 1 / g'(x_i)``, where g' is the derivative used in
        ``solve_unconstrained``."""
        diag_hess = (self.beta / (self.x ** 2) + self.beta * self.gamma / (1.0 - self.x))
        inv_diag = 1.0 / (diag_hess)
        return inv_diag.sum()

    def update(self, action, feedback):
        """Update loss estimates, predictions and weights from one round.

        Args:
            action: indices of the arms that were played.
            feedback: observed losses, aligned with ``action``.
        """
        if len(action):
            self.L[action] += np.divide((np.array(feedback) + self.m[action]), self.x[action])
        if self.unconstrained:
            self.L -= 1
        else:
            self.L += self.bias
            self.bias = 0
        step = 0.25
        for arm, loss in zip(action, feedback):
            self.n[arm] += 1
            # Use the played arm's own x in both numerator and denominator;
            # indexing x by the position inside `action` picks an unrelated arm.
            x_arm = self.x[arm]
            alpha = ((loss - self.m[arm]) ** 2) * min(1, 2 * (1 - x_arm) / (self.gamma * x_arm * x_arm))
            # Convex combination of the old prediction and the new loss.
            # Without the `step` weight on the loss the prediction drifts
            # towards four times the loss level.
            self.m[arm] = (1 - step) * self.m[arm] + step * loss
            # self.gamma is log(T) for the constructor's T, so the instance
            # does not depend on a module-level T.
            self.beta[arm] = np.sqrt(self.beta[arm] ** 2 + alpha / self.gamma)

    def get_learning_rate(self, time):
        """Return the constant learning rate 1.0."""
        return 1.0

class DSBARBAT(Meta_Algorithm):
    """Agent that plays one action from a fixed list according to a fixed
    probability vector, both of which are installed by ``generate_probability``.
    """

    def __init__(self, K, d):
        super().__init__(K, d)
        self.action_list_probability = np.zeros(self.K - self.d + 1)
        self.action_list = []

    def generate_probability(self, N_m , n_m, actions):
        """Install ``actions`` with probabilities ``n_m / N_m``.

        Whatever probability mass is missing from 1 is added to the first
        action. Returns the resulting probabilities scaled by ``N_m``.
        """
        self.action_list = actions
        self.action_list_probability = probs = n_m / N_m
        probs[0] += 1 - probs.sum()
        return probs * N_m

    def next(self):
        """Draw one action from ``action_list`` via a single multinomial trial."""
        draw = np.random.multinomial(1, self.action_list_probability, size=1)[0]
        # Exactly one entry of the draw equals 1; return the matching action.
        for slot, flag in zip(range(self.K), draw):
            if flag == 1:
                return self.action_list[slot]

# --- HYBRID agents and their cumulative-regret curve -------------------------
HY_Agents = [HYBRID(K, d) for _ in range(V)]
for agent in HY_Agents:
    agent.reset()
HY_Regret = np.zeros(T)

# --- DS-BARBAT: parameters of the first epoch (DS_m = 1) ---------------------
DS_m = 1
# NOTE(review): the per-epoch recomputation in the main loop multiplies this
# quantity by K; the initial value does not. Confirm which is intended.
DS_delta_m = ((DS_m + 4) * 2 ** (DS_m + 4)) * np.log(K)
DS_lambda_m = num * np.log(4 * K * DS_delta_m) / V
DS_Delta_m = np.ones(K)
# Length of the first epoch; its end is the first communication round.
DS_N_m = int(K * DS_lambda_m * np.power(2, 2 * DS_m - 2) + 1)
DS_communication = [DS_N_m]
# Each candidate action is the same d - 1 base arms plus one further arm.
base_action = list(range(d - 1))
DS_action_list = [np.hstack((base_action, [slot + d - 1])) for slot in range(K - d + 1)]
DS_n_m = np.zeros(K - d + 1)
for slot in range(K - d + 1):
    DS_n_m[slot] = DS_lambda_m * np.power(DS_Delta_m[slot], -2)
DS_award_m = np.zeros([V, K])
DS_k_m = 0
DS_Agents = [DSBARBAT(K, d) for _ in range(V)]
for agent in DS_Agents:
    agent.generate_probability(DS_N_m, DS_n_m, DS_action_list)
# Keep the values returned by generate_probability (probabilities scaled by DS_N_m).
DS_n_m = DS_Agents[0].generate_probability(DS_N_m, DS_n_m, DS_action_list)
DS_Regret = np.zeros(T)

# --- LBINF agents -------------------------------------------------------------
LBINF_Agents = [LBINF(K, d, T) for _ in range(V)]
for agent in LBINF_Agents:
    agent.reset()
LBINF_Regret = np.zeros(T)

# --- LBINFLS agents -----------------------------------------------------------
LS_Agents = [LBINFLS(K, d, T) for _ in range(V)]
for agent in LS_Agents:
    agent.reset()
LS_Regret = np.zeros(T)

# --- LBINFGD agents -----------------------------------------------------------
GD_Agents = [LBINFGD(K, d, T) for _ in range(V)]
for agent in GD_Agents:
    agent.reset()
GD_Regret = np.zeros(T)

# Shared environment used by every algorithm.
bandit = Bandit(K, d)

def _observed_losses(action_set, budget_index, rewards):
    """Return the losses reported to a learner for the arms in ``action_set``.

    While ``C[budget_index]`` is positive the sampled rewards are ignored:
    arms with index below ``target`` are reported as loss 1 and all others as
    loss 0. Afterwards the loss of an arm is ``1 - rewards[arm]``.
    """
    if C[budget_index] > 0:
        return [1 if arm < target else 0 for arm in action_set]
    return [1 - rewards[arm] for arm in action_set]


def _play_loss_learner(agent, regret, budget_index, rewards, t):
    """Run one round for a learner exposing ``next``/``update``.

    Adds the round's regret to ``regret[t]`` and decrements the learner's
    slot in ``C``.
    """
    action_set = agent.next()
    regret[t] += bandit.generate_regret(action_set)
    feedback = _observed_losses(action_set, budget_index, rewards)
    C[budget_index] -= 1
    agent.update(action_set, feedback)


for t in (tqdm(range(T))):
    if t > 0:
        # Regret curves are cumulative: start from the previous round's total.
        for regret in (HY_Regret, DS_Regret, LBINF_Regret, LS_Regret, GD_Regret):
            regret[t] = regret[t - 1]
    for v in range(V):
        concurrent_award = bandit.generate_award()

        # The learners share the global random state, so the order of the
        # calls below is significant.
        _play_loss_learner(HY_Agents[v], HY_Regret, 1, concurrent_award, t)

        # DS-BARBAT works with rewards, accumulated per arm over the epoch.
        action_set = DS_Agents[v].next()
        DS_Regret[t] += bandit.generate_regret(action_set)
        for i in action_set:
            if C[4] > 0:
                # Reward counterpart of the fixed losses used for the other
                # learners: 0 below `target`, 1 otherwise.
                DS_award_m[v][i] += 0 if i < target else 1
            else:
                DS_award_m[v][i] += concurrent_award[i]
        C[4] -= 1

        _play_loss_learner(LBINF_Agents[v], LBINF_Regret, 5, concurrent_award, t)
        _play_loss_learner(LS_Agents[v], LS_Regret, 6, concurrent_award, t)
        _play_loss_learner(GD_Agents[v], GD_Regret, 7, concurrent_award, t)

    if t == DS_communication[-1]:
        # End of a DS-BARBAT epoch: pool the rewards of all agents, estimate
        # each arm's mean and set up the next epoch.
        DS_award_sum = np.zeros(K)
        for v in range(V):
            DS_award_sum += DS_award_m[v]
        DS_r_k = np.zeros(K)
        DS_r_star = 0
        list_a = []
        # Base arms are part of every action, hence DS_N_m plays per agent.
        for k in base_action:
            DS_r_k[k] = min(DS_award_sum[k] / (V * DS_N_m), 1)
            list_a.append(DS_r_k[k] - np.sqrt(DS_lambda_m / (V * DS_N_m)) / 8)
        # Cover all K - d + 1 candidate slots. Stopping at K - d - 1 left the
        # last two candidate arms with an estimate of 0.
        for k in range(K - d + 1):
            i = DS_action_list[k][d - 1]
            DS_r_k[i] = min(DS_award_sum[i] / (V * DS_n_m[k]), 1)
            list_a.append(DS_r_k[i] - np.sqrt(DS_lambda_m / (V * DS_n_m[k])) / 8)
        # Benchmark: sum of the d largest shifted estimates.
        list_a.sort(reverse=True)
        for k in range(d):
            DS_r_star += list_a[k]
        arg_r_k = np.argsort(-DS_r_k)
        base_action = []
        DS_action_list = []
        # The d - 1 arms with the highest estimates form the new base action.
        for k in range(d - 1):
            base_action.append(arg_r_k[k])
        for k in range(K - d + 1):
            i = arg_r_k[d + k - 1]
            DS_action_list.append(np.hstack((base_action, [i])))
            DS_Delta_m[k] = max(DS_r_star - DS_r_k[i], 2 ** (0 - DS_m))
        DS_m += 1
        DS_delta_m = K * ((DS_m + 4) * 2 ** (DS_m + 4)) * np.log(K)
        DS_lambda_m = num * np.log(4 * K * DS_delta_m) / V
        DS_N_m = int(K * DS_lambda_m * np.power(2, 2 * DS_m - 2) + 1)
        DS_communication.append(DS_N_m + t)
        DS_award_m = np.zeros([V, K])
        # Recompute the play counts for every candidate slot (same bound as
        # the loop that filled DS_Delta_m above).
        for k in range(K - d + 1):
            DS_n_m[k] = DS_lambda_m * np.power(DS_Delta_m[k], -2)
        for v in range(V):
            DS_Agents[v].generate_probability(DS_N_m, DS_n_m, DS_action_list)
        DS_n_m = DS_Agents[0].generate_probability(DS_N_m, DS_n_m, DS_action_list)

# Plot the cumulative regret of every algorithm against the round number.
X = np.arange(1, T + 1)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.rcParams['font.size'] = 16

# (curve, matplotlib format string, legend label), drawn in this order.
for series, fmt, label in (
    (HY_Regret, 'm-', 'HYBRID'),
    (DS_Regret, 'r-', 'DS-BARBAT'),
    (LBINF_Regret, 'k-', 'LBINF'),
    (LS_Regret, 'b-', 'LBINFLS'),
    (GD_Regret, 'g-', 'LBINFGD'),
):
    plt.plot(X, series, fmt, label=label)

plt.xlabel('Rounds')
plt.ylabel('Regret')
plt.grid(True)
plt.legend()
# Scientific notation on both axes.
for axis in ('x', 'y'):
    plt.ticklabel_format(style='sci', scilimits=(0, 0), axis=axis)
plt.show()